import paddle
import numpy as np

# Module-level integer tensor built from a list with a single 1 at index 4,
# created with gradient tracking enabled (stop_gradient=False).
# NOTE(review): `a` is not referenced by the __main__ block in this file —
# confirm whether it is still needed.
a = paddle.to_tensor([0, 0, 0, 0, 1, 0], stop_gradient=False)

if __name__ == '__main__':
    # Demo: select along an axis, take the argmax, and check the result.
    # c holds 0..23 laid out as shape [2, 3, 4], so values strictly
    # increase along the last axis.
    c = paddle.arange(0, 24).reshape([2, 3, 4])
    print(c)

    # Keep only indices 1 and 2 of axis 1 -> shape [2, 2, 4].
    d = paddle.index_select(c, index=paddle.to_tensor([1, 2]), axis=1)
    print(d)

    # Values increase along axis 2, so every argmax is the last index
    # (d.shape[2] - 1). Summing over the d.shape[1] selected rows gives
    # d.shape[1] * (d.shape[2] - 1) for every entry along axis 0.
    g = paddle.argmax(d, axis=2)
    o = paddle.sum(g, axis=1)
    expected = d.shape[1] * (d.shape[2] - 1)  # == 6 for this data

    # Count of entries whose sum deviates from the expected value;
    # prints 0 when argmax behaves as described above.
    print(paddle.count_nonzero(o - expected).item())
